from typing import Tuple

import torch


def initialize_parameters(
    d_ext: int, prior_var: float, device: torch.device
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Build an isotropic Gaussian prior and draw a ground-truth parameter from it.

    The prior is N(0, prior_var * I) over a d_ext-dimensional parameter vector.
    One sample of that prior is returned as the "true" theta for simulation.

    Args:
        d_ext (int): Extrinsic dimensionality of the data (length of theta).
        prior_var (float): Variance of the prior on every coordinate.
        device (torch.device): Device on which the tensors are allocated.

    Returns:
        Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        Prior mean (d_ext,), prior covariance (d_ext, d_ext), and the sampled
        ground-truth theta (d_ext,).
    """
    prior_mean = torch.zeros(d_ext, device=device)
    identity = torch.eye(d_ext, device=device)
    prior_cov = identity * prior_var

    # Sampling through MultivariateNormal consumes the global RNG; callers
    # wanting reproducibility should seed torch beforehand.
    prior = torch.distributions.MultivariateNormal(prior_mean, prior_cov)
    true_theta = prior.sample()

    return (prior_mean, prior_cov, true_theta)


def update_posterior(
    mean: torch.Tensor,
    cov: torch.Tensor,
    x_new: torch.Tensor,
    y_new: torch.Tensor,
    noise_var: float,
    device: torch.device,
    jitter: float = 1e-4,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Update the posterior distribution based on new observations using Bayesian linear regression.

    For the model y = X @ theta + eps, eps ~ N(0, noise_var * I), with current
    posterior theta ~ N(m, S), the update computed here is

        A    = X S X^T + (noise_var + jitter) I       (N, N)
        K    = S X^T A^{-1}                           (D, N)
        m'   = m + K (y - X m)
        S'   = S - K X S      (then symmetrized)

    Args:
        mean (torch.Tensor): Current mean of the posterior (D,).
        cov (torch.Tensor): Current covariance of the posterior (D, D).
        x_new (torch.Tensor): New samples X (N, D). A single sample of shape
            (D,) is also accepted and treated as (1, D).
        y_new (torch.Tensor): New observations y (N,).
        noise_var (float): Variance of the observation noise.
        device (torch.device): Device to use for tensor operations; all inputs
            are moved there.
        jitter (float): Added to the diagonal of A for numerical stability. It
            is also the eigenvalue floor used if the Cholesky factorization of
            A fails.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: Updated mean (D,) and covariance
        (D, D) of the posterior, in the dtype of ``cov``.
    """
    S = cov.to(device)
    dtype = S.dtype
    # Bring every operand onto the same device/dtype; mixing float32 and
    # float64 would otherwise fail inside matmul/cholesky_solve.
    x_new = x_new.to(device=device, dtype=dtype)
    if x_new.dim() == 1:
        x_new = x_new.unsqueeze(0)  # single sample (D,) -> (1, D)
    y_col = y_new.to(device=device, dtype=dtype).reshape(-1, 1)  # (N, 1)
    m = mean.to(device=device, dtype=dtype).reshape(-1, 1)  # prior mean (D, 1)

    n_obs = x_new.shape[0]
    eye_n = torch.eye(n_obs, device=device, dtype=dtype)

    SX = S @ x_new.t()  # (D, N)
    denom = x_new @ SX + noise_var * eye_n  # (N, N)
    # Remove round-off asymmetry, then add jitter for positive definiteness.
    denom = (denom + denom.t()) / 2
    denom = denom + jitter * eye_n

    try:
        denom_chol = torch.linalg.cholesky(denom)  # lower-triangular L
        denom_inv = torch.cholesky_solve(eye_n, denom_chol)  # (N, N)
    except torch.linalg.LinAlgError:
        # Not positive definite even after jitter: clamp the spectrum and
        # invert through the eigendecomposition directly. (A symmetric square
        # root is NOT a Cholesky factor and must not be fed to cholesky_solve.)
        eigvals, eigvecs = torch.linalg.eigh(denom)
        eigvals = torch.clamp(eigvals, min=jitter)
        denom_inv = eigvecs @ torch.diag(1.0 / eigvals) @ eigvecs.t()

    K = SX @ denom_inv  # (D, N)
    residual = y_col - x_new @ m  # (N, 1)
    new_mean = m + K @ residual  # (D, 1)
    new_cov = S - K @ SX.t()  # (D, D)

    # Ensure symmetry of covariance matrix
    new_cov = (new_cov + new_cov.t()) / 2
    # squeeze(1), not squeeze(): keep shape (D,) even when D == 1.
    return new_mean.squeeze(1), new_cov
